{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <div style='color:#00c2ea;font-weight:800;'>模型训练参数设置</div>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nes_py.wrappers import JoypadSpace\n",
    "import gym_super_mario_bros\n",
    "from gym_super_mario_bros.actions import SIMPLE_MOVEMENT\n",
    "import time\n",
    "from matplotlib import pyplot as plt\n",
    "from gym.wrappers import GrayScaleObservation\n",
    "from stable_baselines3.common.monitor import Monitor\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "from stable_baselines3.common.vec_env import VecFrameStack\n",
    "import os\n",
    "from stable_baselines3 import PPO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym_super_mario_bros.make('SuperMarioBros-v0')\n",
    "env = JoypadSpace(env, SIMPLE_MOVEMENT)\n",
    "\n",
    "monitor_dir = r'./monitor_log/'\n",
    "os.makedirs(monitor_dir,exist_ok=True)\n",
    "env = Monitor(env,monitor_dir)\n",
    "\n",
    "env = GrayScaleObservation(env,keep_dim=True)\n",
    "env = DummyVecEnv([lambda: env])\n",
    "env = VecFrameStack(env,4,channels_order='last')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 一、直接设置参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda device\n",
      "Wrapping the env in a VecTransposeImage.\n"
     ]
    }
   ],
   "source": [
    "tensorboard_log = r'./tensorboard_log/'\n",
    "# learning_rate\n",
    "# n_steps\n",
    "model = PPO(\"CnnPolicy\", env, verbose=1,\n",
    "            tensorboard_log = tensorboard_log)\n",
    "# model.learn(total_timesteps=25000)\n",
    "# model.save(\"mario_model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dir(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0003\n",
      "2048\n"
     ]
    }
   ],
   "source": [
    "print(model.learning_rate)\n",
    "print(model.n_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda device\n",
      "Wrapping the env in a VecTransposeImage.\n"
     ]
    }
   ],
   "source": [
    "tensorboard_log = r'./tensorboard_log/'\n",
    "learning_rate = 1e-6\n",
    "n_steps = 128\n",
    "model = PPO(\"CnnPolicy\", env, verbose=1,\n",
    "            tensorboard_log = tensorboard_log,\n",
    "            learning_rate = learning_rate,\n",
    "            n_steps = n_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1e-06\n",
      "128\n"
     ]
    }
   ],
   "source": [
    "print(model.learning_rate)\n",
    "print(model.n_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 二、通过字典设置参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_param_1={\n",
    "    'learning_rate' : 1e-8,\n",
    "    'n_steps' : 1024\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda device\n",
      "Wrapping the env in a VecTransposeImage.\n"
     ]
    }
   ],
   "source": [
    "tensorboard_log = r'./tensorboard_log/'\n",
    "\n",
    "model = PPO(\"CnnPolicy\", env, verbose=1,\n",
    "            tensorboard_log = tensorboard_log,\n",
    "            **model_param_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1e-08\n",
      "1024\n"
     ]
    }
   ],
   "source": [
    "print(model.learning_rate)\n",
    "print(model.n_steps)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "9465aae7e0ab1403d672807d1a0963d86dbda2f584fbe3054c36cf78311c6c77"
  },
  "kernelspec": {
   "display_name": "Python 3.8.11 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
